{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quick Start PyTorch - MNIST\n",
    "\n",
    "To run a Code Cell you can click on the `⏯ Run` button in the Navigation Bar above or type `Shift + Enter`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    }
   ],
   "source": [
    "%pylab inline\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.utils.data.dataloader as dataloader\n",
    "import torch.optim as optim\n",
    "\n",
    "from torch.utils.data import TensorDataset\n",
    "from torch.autograd import Variable\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import MNIST\n",
    "\n",
    "SEED = 1\n",
    "\n",
    "# CUDA?\n",
    "cuda = torch.cuda.is_available()\n",
    "\n",
    "# For reproducibility\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "if cuda:\n",
    "    torch.cuda.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Processing...\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "train = MNIST('./data', train=True, download=True, transform=transforms.Compose([\n",
    "    transforms.ToTensor(), # ToTensor does min-max normalization. \n",
    "]), )\n",
    "\n",
    "test = MNIST('./data', train=False, download=True, transform=transforms.Compose([\n",
    "    transforms.ToTensor(), # ToTensor does min-max normalization. \n",
    "]), )\n",
    "\n",
    "# Create DataLoader\n",
    "dataloader_args = dict(shuffle=True, batch_size=256,num_workers=4, pin_memory=True) if cuda else dict(shuffle=True, batch_size=64)\n",
    "train_loader = dataloader.DataLoader(train, **dataloader_args)\n",
    "test_loader = dataloader.DataLoader(test, **dataloader_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Train]\n",
      " - Numpy Shape: (60000, 28, 28)\n",
      " - Tensor Shape: torch.Size([60000, 28, 28])\n",
      " - min: 0.0\n",
      " - max: 1.0\n",
      " - mean: 0.13066047740240005\n",
      " - std: 0.3081078089011192\n",
      " - var: 0.0949304219058486\n"
     ]
    }
   ],
   "source": [
    "train_data = train.train_data\n",
    "train_data = train.transform(train_data.numpy())\n",
    "\n",
    "print('[Train]')\n",
    "print(' - Numpy Shape:', train.train_data.cpu().numpy().shape)\n",
    "print(' - Tensor Shape:', train.train_data.size())\n",
    "print(' - min:', torch.min(train_data))\n",
    "print(' - max:', torch.max(train_data))\n",
    "print(' - mean:', torch.mean(train_data))\n",
    "print(' - std:', torch.std(train_data))\n",
    "print(' - var:', torch.var(train_data))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One hidden Layer NN\n",
    "class Model(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Model, self).__init__()\n",
    "        self.fc = nn.Linear(784, 1000)\n",
    "        self.fc2 = nn.Linear(1000, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.view((-1, 784))\n",
    "        h = F.relu(self.fc(x))\n",
    "        h = self.fc2(h)\n",
    "        return F.log_softmax(h)    \n",
    "    \n",
    "    \n",
    "model = Model()\n",
    "if cuda:\n",
    "    model.cuda() # CUDA!\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train\n",
    "\n",
    "Training time:\n",
    "\n",
    "- CPU, about 1 minute and 30 seconds\n",
    "- GPU, about 10 seconds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " Train Epoch: 1/5 [60000/60000 (100%)]\tLoss: 0.282395\t Test Accuracy: 94.8300%\n",
      " Train Epoch: 2/5 [60000/60000 (100%)]\tLoss: 0.239222\t Test Accuracy: 96.7600%\n",
      " Train Epoch: 3/5 [60000/60000 (100%)]\tLoss: 0.055666\t Test Accuracy: 97.3800%\n",
      " Train Epoch: 4/5 [60000/60000 (100%)]\tLoss: 0.124187\t Test Accuracy: 97.7700%\n",
      " Train Epoch: 5/5 [60000/60000 (100%)]\tLoss: 0.008199\t Test Accuracy: 97.8100%\n"
     ]
    }
   ],
   "source": [
    "EPOCHS = 5\n",
    "losses = []\n",
    "\n",
    "model.train()\n",
    "for epoch in range(EPOCHS):\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        # Get Samples\n",
    "        data, target = Variable(data), Variable(target)\n",
    "        \n",
    "        if cuda:\n",
    "            data, target = data.cuda(), target.cuda()\n",
    "        \n",
    "        # Init\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Predict\n",
    "        y_pred = model(data) \n",
    "\n",
    "        # Calculate loss\n",
    "        loss = F.cross_entropy(y_pred, target)\n",
    "        losses.append(loss.cpu().data[0])\n",
    "        # Backpropagation\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        \n",
    "        # Display\n",
    "        if batch_idx % 100 == 1:\n",
    "            print('\\r Train Epoch: {}/{} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch+1,\n",
    "                EPOCHS,\n",
    "                batch_idx * len(data), \n",
    "                len(train_loader.dataset),\n",
    "                100. * batch_idx / len(train_loader), \n",
    "                loss.cpu().data[0]), \n",
    "                end='')\n",
    "    # Eval\n",
    "    evaluate_x = Variable(test_loader.dataset.test_data.type_as(torch.FloatTensor()))\n",
    "    evaluate_y = Variable(test_loader.dataset.test_labels)\n",
    "    if cuda:\n",
    "        evaluate_x, evaluate_y = evaluate_x.cuda(), evaluate_y.cuda()\n",
    "\n",
    "    model.eval()\n",
    "    output = model(evaluate_x)\n",
    "    pred = output.data.max(1)[1]\n",
    "    d = pred.eq(evaluate_y.data).cpu()\n",
    "    accuracy = d.sum()/d.size()[0]\n",
    "    \n",
    "    print('\\r Train Epoch: {}/{} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\\t Test Accuracy: {:.4f}%'.format(\n",
    "        epoch+1,\n",
    "        EPOCHS,\n",
    "        len(train_loader.dataset), \n",
    "        len(train_loader.dataset),\n",
    "        100. * batch_idx / len(train_loader), \n",
    "        loss.cpu().data[0],\n",
    "        accuracy*100,\n",
    "        end=''))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f95b55620f0>]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAD8CAYAAAB9y7/cAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl4VNX9x/H3d2YSdmQLiGwBARUrCiIiLnVBRbRal1bc\ntVrbWlurVkWt1mpd21r1516wLnWra10AFUXEBSQouywhbEGWsIVACNnO74+5M0ySWRIIJHf8vJ6H\nhzt3bu49lxs+c+bcc88x5xwiIpJ+Ag1dABER2T0U8CIiaUoBLyKSphTwIiJpSgEvIpKmFPAiImlK\nAS8ikqYU8CIiaUoBLyKSpkINdeAOHTq47Ozshjq8iIgvTZ8+fZ1zLqs22zZYwGdnZ5OTk9NQhxcR\n8SUzW1bbbdVEIyKSphTwIiJpSgEvIpKmFPAiImlKAS8ikqYU8CIiaUoBLyKSpnwX8AtWF/GPDxew\nfsv2hi6KiEij5ruAX1ywhf/7JJd1W0obuigiIo2a7wI+FDAAyioqG7gkIiKNm+8CPiMYLrICXkQk\nOR8HvGvgkoiING6+C/hQMNxEU64avIhIUr4L+GgNvlI1eBGRZHwY8N5N1nLV4EVEkvFdwIcC4SKX\nVyrgRUSS8V3AZ4bCNfhS3WQVEUnKdwEfrcHrJquISFK+C/iMUCTgVYMXEUnGfwEfiDTRqAYvIpKM\n7wI+FFQTjYhIbfgu4KPdJNVEIyKSlA8DPvKgk2rwIiLJ+DbgdZNVRCQ53wV8MGCYaTRJEZFUfBfw\nABmBgNrgRURS8GfAB001eBGRFHwZ8KFgQN0kRURS8GXAZwQDGi5YRCQFnwa8abhgEZEUUga8mXUz\ns4lmNs/M5prZNXG2MTN7xMxyzWyWmQ3cPcUNCwWNctXgRUSSCtVim3LgeufcN2bWCphuZh855+bF\nbHMK0Mf7czjwhPf3bpERDGgsGhGRFFLW4J1zq5xz33jLRcB3QJdqm50BPO/CpgBtzKxzvZfWkxHQ\nTVYRkVTq1AZvZtnAAGBqtbe6ACtiXudT80Og3mSETE+yioikUOuAN7OWwBvAH5xzm3fmYGZ2pZnl\nmFlOQUHBzuwCCE/6oSYaEZHkahXwZpZBONxfdM69GWeTlUC3mNddvXVVOOeeds4Ncs4NysrK2pny\nAuFeNKrBi4gkV5teNAaMAb5zzj2YYLN3gIu93jRDgELn3Kp6LGcVGcGAnmQVEUmhNr1ojgQuAmab\n2Qxv3S1AdwDn3JPAWGAEkAsUA5fVf1F3CAUDbC2t2J2HEBHxvZQB75z7HLAU2zjgt/VVqFQyg6Ze\nNCIiKfjySdZQQE00IiKp+DLgM0IB3WQVEUnBnwEfMHWTFBFJwZcBH1I3SRGRlHwZ8BnBAOWadFtE\nJCnfBnyphgsWEUnKpwGv4YJFRFLxZcCH9CSriEhKvgz4jIBRVuEIP18lIiLx+DPgg+Fiq5lGRCQx\nXwZ8KBLw6iopIpKQLwM+IxgeGqdMXSVFRBLyacCHi12mrpIiIgn5MuBDXg1ebfAiIon5MuAjNXg9\n7CQikphPA141eBGRVHwa8JFeNKrBi4gk4suADwW8JhoFvIhIQr4M+GgTjfrBi4gk5NOA97pJqgYv\nIpKQLwM+0k2yTDV4EZGEfBnwmdGxaFSDFxFJxJcBH1ITjYhISv4M+ICaaEREUvFlwGeGVIMXEUnF\nlwEfqcGrm6SISGK+DHh1kxQRSc3nAa8avIhIIr4M+B3DBasGLyKSiC8DXsMFi4ik5tOA13DBIiKp\n+DTgNVywiEgqvgz4SDfJUt1kFRFJyJcBb2aEAqYavIhIEr4MeAg306gfvIhIYikD3syeMbO1ZjYn\nwfvHmlmhmc3w/txe/8WsKRQ09YMXEUkiVIttngUeBZ5Pss1k59xp9VKiWsoMBtQPXkQkiZQ1eOfc\nZ8CGPVCWOgkFjbJy1eBFRBKprzb4I8xsppmNM7MDE21kZleaWY6Z5RQUFOzSAUOBAGWqwYuIJFQf\nAf8N0MM5dzDwf8DbiTZ0zj3tnBvknBuUlZW1SwfNDAXUBi8iksQuB7xzbrNzbou3PBbIMLMOu1yy\nFNRNUkQkuV0OeDPb28zMWx7s7XP9ru43lXA3SdXgRUQSSdmLxsxeBo4FOphZPvBnIAPAOfckcA7w\nGzMrB7YBI51zuz15M4KmfvAiIkmkDHjn3Hkp3n+UcDfKPSqkbpIiIkn5+ElWdZMUEUnGxwGvbpIi\nIsn4OuA16baISGK+DfhQQDdZRUSS8W3AZ4Q0mqSISDL+DfiARpMUEUnGtwEfCgb0JKuISBK+Dfhw\nLxrV4EVEEvFxwOsmq4hIMj4OeHWTFBFJxrcBHwoaparBi4gk5NuAzwjoJquISDL+DfhggEoHFbrR\nKiISl28DPhQ0AN1oFRFJwLcBnxkMF71cNXgRkbh8G/DRGny5avAiIvH4OODDRdeQwSIi8fk24DOj\nbfBqohERice3AR8KeG3wuskqIhKXbwM+I+Q10agGLyISl38DPqBukiIiyfg24CM3WTUejYhIfL4N\n+AzvJqvGoxERic/HAa+brCIiyfg/4PUkq4hIXL4N+JCaaEREkvJtwGcEdJNVRCQZ/wZ8SN0kRUSS\n8W3AR55kVcCLiMTn24DPVD94EZGkfBvwmvBDRCQ5/we8ukmKiMTl24CPNNFowg8Rkfh8G/DRsWg0\n4YeISFwpA97MnjGztWY2J8H7ZmaPmFmumc0ys4H1X8yaMjThh4hIUrWpwT8LDE/y/ilAH+/PlcAT\nu16s1DLUTVJEJKmUAe+c+wzYkGSTM4DnXdgUoI2Zda6vAiYSCBgBUzdJEZFE6qMNvguwIuZ1vreu\nBjO70sxyzCynoKBglw+cEQyoBi8iksAevcnqnHvaOTfIOTcoKytrl/cXDnjV4EVE4qmPgF8JdIt5\n3dVbt9tlBE29aEREEqiPgH8HuNjrTTMEKHTOraqH/aYUUhONiEhCoVQbmNnLwLFABzPLB/4MZAA4\n554ExgIjgFygGLhsdxW2uoyAqYlGRCSBlAHvnDsvxfsO+G29lagOMkKqwYuIJOLbJ1kBggGjpKyi\noYshItIopazBN2Z5BVvJK9hKRaUjGLCGLo6ISKPi6xp8hJppRERqSouAr9CQwSIiNaRFwJcr4EVE\navB1wO/duimgGryISDy+Dvirj+8NaEx4EZF4fB3wIa/njGrwIiI1+TrgI10jNWSwiEhNvg74yMTb\nqsGLiNTk64APBiLzsirgRUSq83XAqw1eRCQxXwd8tA1evWhERGrwdcCrBi8ikpivA35HDV4BLyJS\nna8DPuTdZFUNXkSkJl8HvPrBi4gk5uuAj/SDr3QKeBGR6nwd8M0zgwBs2FrawCUREWl8fB3wfTu1\nIiNozFu1uaGLIiLS6Pg64DOCAVo1zaCopKyhiyIi0uj4OuABmmUEKS7VxNsiItX5PuBbNAlSvF0B\nLyJSne8DvnlmiK2l5Q1dDBGRRicNAj7INjXRiIjU4PuAb5YRJGfZRnWVFBGpxvcB36VtMwD+N2Nl\nA5dERKRx8X3ARybe/su78ygpU1ONiEiE7wO+Q4sm0eX9bxvfgCUREWlcfB/wAW/AMRERqcr3AQ9w\nWv/ODV0EEZFGJy0Cvm3zzBrrKiud2uRF5ActLQL+mmF9aqy787157H/beMoqNF+riPwwpUXAt2wS\nqrHulWnLARTwIvKDVauAN7PhZrbAzHLNbFSc9y81swIzm+H9uaL+i5pYk1DN0zA0X6uI/LDVrPpW\nY2ZB4DHgRCAfmGZm7zjn5lXb9FXn3NW7oYwpmSXuSaPp/ETkh6o2NfjBQK5zLs85Vwq8Apyxe4u1\n6yKZX64mGhH5gapNwHcBVsS8zvfWVXe2mc0ys9fNrFu9lK4OendsCcBH89YAEPASvixBE83guyfw\nr8/y9kzhREQaQH3dZH0XyHbO9Qc+Ap6Lt5GZXWlmOWaWU1BQUE+HDhvxo70B+OXzOfz8qa+INNok\nqsGvLdrO3WO/q9cyiIg0JrUJ+JVAbI28q7cuyjm33jm33Xs5Gjg03o6cc0875wY55wZlZWXtTHkT\nKtiyPbr89ZIN0eWyOG3wzqldXkTSX20CfhrQx8x6mlkmMBJ4J3YDM4t9lPR0YI9XjasHeVllpbe+\nZg0+XuiLiKSblL1onHPlZnY18AEQBJ5xzs01szuBHOfcO8Dvzex0oBzYAFy6G8scV0W1tvaSsnCw\nV+9FU15RyUtTl+2xcomINJSUAQ/gnBsLjK227vaY5ZuBm+u3aHWTqL/7jPxNPDEpl94dW3Hh4d15\nf/Yq/vJu9R6eIiLpp1YB7weVCQL+trfneEurmZK3niN6td9zhRIRaUBpMVQBQHll6v7uazeXMPf7\nwj1QGhGRhpc2NfjqbfDxrN5cwtL1xXugNCIiDS+NavCpAz5y41VE5IcgbQK+NjX4eLaVVtD31nE8\n+OGCei6RiEjDSpuAH9C97U793JzvCymtqOSRT3JZu7kk2m9+yD0f8+BHC+uziCIie1TaBPw1J9Sc\n9KM2Ymv+g+/5ONrrZvXmEh75eFG9lE1EpCGkTcAHA8aE646p089kBK1G084Hc1fXGMrg0wVr+Wb5\nxjrtW9MFikhDS5uAB+jdsVWdti+rcLw/e1WVdRuLy5j7/eYq6y799zTOevzLWu93w9ZS9r9tPKMn\na7RKEWk4aRXw1R3crU3KbV6aurzGuouf+XqXjrt5WxkAD09QE4+INJy0C/heWS04sV8nlt53KpcO\n7bFT+yguLY+7vrS8kqcmLU7Y/FJYXMb6Ldup8Jp4irbH34+IyJ6QdgH/yfXH8q+LBwGQalTgW0bs\nH3f99vId/eVjh0B4Ycoy7h03n+v+O4M3pueTPep9NhWXRt8fcNeHHPrXCZomUEQahbQL+FipusZ3\na9s87vrYD4atMbX5hyeEu02Onb2a61+bCcBUb+z5sorK6PFihyjeXr7rN1tz126p801eEZG0DvhA\n4rm4AcgMxT/9Lm2aRZc3l5THXY741QvTuX/8fC4aMzW6Ljbg66Or5bAHJ3HW41/yRe465q/enPoH\nRERI84AfcVBnzhvcnV5ZLeK+3yurZdz1Kzdtiy4fed8nKY/zxKeLmZK3Yxap2GETlm8I7+vBDxeQ\ns3THNp8vWses/E1x9zfv+828/HXNm78XjJ7K8IcmpyyPiAikecA3zQhy71kHkdWySdz3e3ZowYzb\nT6z345bGtOG3aZaBc45HPsnlnCe/is4Re+GYqZz+6Bdxf37EI5O5+c3Z9V6u16fn896s7+t9vyLS\nOKXNaJLJtGyy4zQvP6onYz5fEn3dpnlmvR/v9v/NiS5vLC6tctO2963jGHZAx4Q/OzVv/U4ds7C4\njMxQgGaZwYTb/NG7b3Ba/3126hh70hvT86lwjp8P6pZ6YxGJ6wcR8Pef059b35rNiIM6c8YhXbjy\nmF5xe9icN7gbL3+9YpePt7hgKwChgLGxuJRtpVVvtE74bm10uaBoO1mtwt8wlq8v5tynp8R9L5WD\n7/yQ7u2a89mNx+1q8RuFyE1sBbzIzkvrJpqIDi2b8NRFgzjjkC4AdGrdlL33alpju3vP6s9or4tl\nxFtXDeW93x1VY9t/X3ZYyuO2bBrii9z1PDQh8aBlh909gdWFJbw4dRkl1XrcHHb3hJTHiLV8QzEz\nVmwir2AL785UU4zID90PogafykPnHkJnL/CH9esUXT/3LyfTwmvemfjHYznu759G32tbrWnnoC57\nMXtlIYd0a8OMFeGbp1u8XjfPfZV8ku8Hxs/nzW9Xct2JfWu8V9dhkH/62I52/RP7daJpRrjJpvr4\nOvWhtLySvn8ax10//REXDdm5h8pEZPf5QdTgU/npgC4cHmeu1hYxbfc9O7TghcsHR1+3bFK1rfvd\n3x3F0vtOjX5QAByW3a5Wx498IMQbnnhaTM+biDem53PqI5NZvr6YL3PXJdzvqsIS1haVAFUf3ooo\nKimr0qWzrgq9IRnuG/sdD4yfz9ZqT+5u3Foa3SaRykqHc44HP1zAB3NX73RZRKQm1eDjyPnTsLhP\nox7dJ4tu7ZqxYsM2MoPxb2b+9rjejJsTDqr/O38Ag/66o5mlZ4cWLFm3tcbP5MVZFzEypk0+ItI+\nfczfJgJwx0/6xf3ZyDeOpfedSkHR9uj6TcWlHPPAxGi//kk3HMvzXy3jwiE96NkhfpfSeCIfDltL\nK3j808U0zQjy+xP68Oq05XRr25zzR4efDVh09ylkBOPXJXrdMpbzBnePdgtdet+pNbZxzvHClGX8\npP8+tG1R/zfFRdKVavBxdGjZJG4bPcBRvTsAsFezjLjv7xPzkFSHlk0YcdDe0defXP/jKgFWm3b8\n2rjj3XlJ3z/yvk94OOaBqw/mrq7y0NaP//YpYz5fwsinvwLCgTp6ch7Zo95n0F8/InvU+zw0YSHZ\no96v8kGxrdqYPEY49G96Y3Y03AH63DqOL3LXMXlRAdmj3uezhQXAjmEg4vX5jzV7ZSG3/28uN74x\nK+l2u+Lj79ZEv+2IpAsFfB395fQfMemGY9mr+Y6Aj70J26Ta07GPX3BodNlsx6O1R/Rqz5H7dtiN\nJd1h5aZtvD49P/o60UxVazZv5+Wvl9Pz5rH89f3vAFi3JTzWzkPeyJiH3T2B+as3c+bjX7Bmc9VA\nbN4kxOrC+CH51rcro5OpXPzM11RWurjNRvFE5tLduLU0xZY7bCut4OslNZu34qmodFz+XA7n/2tq\nwm1e/no5T01anHQ/m4pLGT05b7fc7xDZGWqiqaPMUIAe7as2Y/yoy17R5eoBD+GbnbEjUC65dwQQ\nDvzT+nfmvVmravzM7rRm8/aE79XmAavI07TVA7GsopItCUbQjP2AAVixsZgFq4tSHgt23CC2akNP\nTMlbz+Dsdtz/wXx6tGvB+Yd3j7532bNfMyVvA5NvPI5u7ZrzwpRlNA0F+FmcbpeR8YKWrNvKbW/P\noaBoO307teS6k/aLbhP5d/nVj/dNWM4bX5/Fh/PWMKB7GwZ0a0sg1VgZu2hV4TZGT17CLSMOILib\njyX+pIDfBWN/fzSLC7ZUWReK09b8r2pdL2Nr8odlt9vjAb+7fJG7jn/UcvLyL3LXc8tbqT9Mxny+\nhNdyws8mGOF/t4/mrWHmik08OjGXX/94X56aFJ5Y5fzDu7Nxayl567ZGh44o8pqiIt8e4ga89w3B\nCI8YCjB+LlUCPmJW/ib6d40/z0DkG83ZT3zFfp1aMe6ao3dryI96YzaTFhZwYr9ODInTSWB3WLim\niKlLNnDqQZ1pp/shjZ6aaHZBv31a85ODd+2p0OrdIJtmJL4k+yS4L9BYTF60jrJaDpWcKNyzR71f\nZSygu96bx3yvph/5XPzl8zk8OjEXgCdjmk0mLljLgLs+4olPd6yrdC7l9ImR5w/Ka9El9T9TljFx\n/lq2lVZQWFzGTa/PYnNJuKdQacy5L1hTxLlPf8XLXy9ne3kFU/LW8+q05Pca6qq8MvzBtCeHpz7p\nn59x29tzGFzHZzQSufWt2fxvxsp62ZfUpIBvYKFg1RrepUN7AvDu1UfVaO55+uJBLLr7FJbedypH\n9m6fNPD/ee7BdG8Xfzjkxu7ZL5bEXW8GZz+ReOrEP/9vLgATvlsTXbe2qIT9bxtfZbuLxkzlmAcm\ncv/4+WSPep9Fa6p+C0tmwZotXPbsNK58IYeD7/yQV3NW8IL3nEP1bqLTlm7k5jdnc+2rMxj59BRu\nemM22aPeJ39jMRDuZlq9m2ptupZGBLxPvIqYNv8vF69jzeaS6IfOta/O4IHx82v8bHlFZcJnLCoq\nHd8s38jE+WurzHdQ5efr+HxGIi9OXc41r8yol33FU7itjBtem5mw6TDdqYlmN4ntPZPMzw7tRl7B\nVp79cikAN568H6f178yPuuzF9Sf15Z6x8/nJwfvQr3NrDtyndbR558UrhgDhABt898c19nvmgK5k\nBANc/dK3SY8/OLsdX8fpax/r85uO46j7J8Z97/bT+nFwtzbc+e5cZuYXpjrduPp2asnCmJD91+T4\nAR87Ymc8yzcU11g3q1qZLhozlcmLws8ORGr6iaZodM5VaU4L7y/8zEJkHxCeAeyN6flxjw/h+QNi\nPTUpjyN7t+fX//kmuu4nB+9D7totfLcqPBz0ExcMZGCPtixZt5Wxs1dx22n92FRcRlFJWXQU1EjA\nR3ojOeeq3BfJ+dMw3vo2XDu+cfj+5CzdwOjJS3j0/AH0vnUcQ3q145Urj6hR3kc/yeWf3tPXh/ds\nx6u/qrkNwOKCLXRt24wmoZpdhldu2sZpj0zmlSuPYL+9W7FwTRFL1m3l5AN3/L/YmYnpN5eU0Swj\nyMI1RWS1bELH1jsqOZ8vWkdxaTknxRzjqUmLeW16Pvt2bMmvk9w/2ZPi/V7tLgr43SBeX+5EmmUG\nueP0Azksux2rCrcRCFj0pu0g70Gpswd24dj94g9Q1rFVU/4wrA+GRf9TvvGboUB4ULEXpyznhAM6\nYmaMm72KnGVVJw4Zfekgnv9yKZcf1YunPltMWUUl781axbL1O8KqQ5zRONs2z2Dovh34xVHhbxzH\n79+JmfmFHNm7PR1bNeWtb1dy9sCuvPFNfo2fjfXEBQM5tEdbBt9T80OqPjxUbV7c2GBOZfXmEoIB\nq1LTjddBprzSRZ9NqI0XpiyLtvVHVB9a4jcvflPl9eCe7aIf1o9fMJARB3WO3lh9bGIuG4tLObV/\n56r7+M/06PK9Y7/jqc/C9yoiYyVNydvA9vIKmoSC3Pb2HIb/aG+2lVZEf48A5q1KPP/ACf+YxGn9\nO/Po+QOB8E3fNZu3c0i3Nrw/63s2Fpdx8kOfccPJ+/G3D8L3ZmL/b2yI6RVVm9BzztH/jg/5ycH7\n8O7M72meGWTencMBWLC6iAu9ORm+ue3E6P2ByLebXenYNDu/kE3bSjm6T1bS7baXV7Dfn8Zz9XG9\n2W/vVtHm2/OensLgnu249sS+zFyxiTMe+4KXfnk4Q/dALzoFfCNR/T8nwMDubZl/1/DocAOJ/GFY\nXxYXbOGfExZyaI+2HNqjbfS9l68cEl2+4PDujJ6cR4eWTRjl9Qpp3TSDq4/vE90PwNJ1xSxbX8x/\nf3UEwUC4Z1CrpqHoDUuAr24+oUq5IoOilZW76H+uHu1TNxGdclD4vMf+/mhGPNI4xrofum97vly8\nniPuTT0XAIRH8tzdYr+J3fDaTK5+6ZvoDGI5yzaSs2wj1/236ofMtKU7Pswj4Q4w9/sd32pOeWhy\n9EG7F6cuqzELWlFJOdmj3uePJ/WlW5wmv/dmreLhkY5nv1zKgx8uYGtpBW9eNZR7xu5oFoqEO8CE\neWt44IP5PPeLwVUCvufNY7nymF5cNKRH3OPAjucuIh+GxTGD+P3syR1Nd2c/8SUT/3hs3H0ksm7L\ndk55eDJjLhlE/65tmL5sI1c8N43rTuzLbV7T338uP5yj+oRD+YbXZrJgTRFvX3Vk9EZ65Pcgcn/o\n6D4d2KtZBl/lreervPVce2JfPpkfHmhw0sKCPRLwaoNv5FKFe0SvDi24ZcT+PHr+gKT7uvr4Powc\n3D3hNgB/+1l/nv/FYAb3bMehPdphZsy+42Q++MMxXDo0m7x7RtQoVyTMM0JGkdf+u3frpjz484Oj\n21SfAzf2HkK/fVpz22nxn8j98NpjyE7xYXHT8Pjz69bV17eeUOt/84hXpu0YgTTZTfL6srW0IuV0\nlMnEfnuIfYo62T7//uHChG3l+94ylrvem8dWL3DPejzxfZIrns9h4ZotHHHvJ9HpLiOe/iyPox+Y\nyH+nreCJTxfjnGN2fiGPfrKI5euLySuo+cT3JO+hudgH95as28qXi9cxenIeeOeU6MvBjBWbyN9Y\nzIR5aygo2s6/v1jK69PzOfuJL9lYXBYNdwjP4VBSVsH4Oat5bXo+s/ILo/c6oOaDf4fc+RE9bx4b\nff1F7rroA4ehPdSt1RrqoYxBgwa5nJycBjm2hNtIt24vp2+nVvWyv5KyCv7+wQKuOLoXFc5x/7j5\n3Hf2QTTPDJE96n0g/PX82+UbOfPxL+nbqSXv/e7oGtMm/vuLJfzl3Xncdlo/7npvHif168TTFw/C\nOcdFY77mc2/snSX3jsDMOPepr5i6ZAMvXD6Yb5ZtqtK8APC743uzdH1xlSaQY/pmcc6hXfn9yzXv\nTyy971Te/nYlf3i1Zpjdd9ZBPDoxl1P7d+bEAzpxzpNfVXn/njMPYuRh3dj/9vFVJn2p7oDOraNt\n7RD+UIg8zCX1K/b+zoTrjmHJumKuenF6ld5eTUKB6EN3qZ5Lef4Xg6vcszluvyx6dmhJRsg4br+O\ncYcWief3x/eO2w23NsxsunNuUOotFfCyB7z5TT5FJeVcMjQb5xyPTczlzIFdq8x9W1tf5q6jSUaA\nQ3uE70/krt3CPz9ayD9+fjBNM4Jc9+oMpi7ZwMpN2xi6b3te+uUQCreV8eSkxdGbqnn3jCAQMKYv\n21ilV87M20+KPqEc+VCC8Fg/Iwd3j1uzf316fnQilch+c9cW8a/PlvBqTtW5BYYd0IlrTuhDZijA\nyQ99xrADOpKzbCMdWzWpcpO5NiJjIkXsv3eraHfSxubcQd2Yv6aImSviT1H5Q3TtsL5cM6zPTv1s\nXQK+Vt8nzWy4mS0ws1wzGxXn/SZm9qr3/lQzy65bkSWdnTWwK5cMzQbCD3ldfXyfnQp3gKG9O0TD\nHaB3x5Y8dsHAaPg+eO4hfH7Tccy64ySevSw8+udezTK4afj+DDugE4d0axNtMz20R1smXPdj/nP5\n4bzxm6FVhp+Yf9dw+ncN3+xu2yIzYbPNOYd2jS5H9tu7YytuOfUAIPzk85hLBvGnUw9g9CWDOKjr\nXuy3dyuW3ncqoy85jOl/OpGOrcJNVY+dP5AFfx3OyMO68Ysje3LgPq2rHCu22+tbVx0ZXT5zQBfG\nXXM0R/Zuz9B9ww88tW+RySPn1Wyui31u47Vfx+8dE7H/3ju+3T1y3gAyqz3E9+IVhyf9eYAD92nN\n/ef0Z8wlNfPozjMOrNKEF8/ZA7vGXd8iycxle8KTFw7cpZ9PMPZevUtZgzezILAQOBHIB6YB5znn\n5sVscxXbHYV8AAAIFElEQVTQ3zn3azMbCZzpnDs32X5Vg5fGrqSsglenreDCIT2SDgWwdnMJZZWu\nyoeWc47Lnp3GJUOzOS5BD6iIVYXbGDN5CTcnGXIg0svk0U8WsWV7BaNO2Z85KwtZs7mEEw7oVGXb\ntZtLaNU0g2aZQcbNXsWS9VtxDoIB49Kh2ex/23gO6daGt397JE98upgnJy3m0fMHsHV7BT/um8U/\nJyzk/MHd6dauOWUVlRQUbadbu+bM+34zYz5fQsDgten5LL3vVBYXbKGsopJpSzYwM7+QrFZN2FRc\nxpeL17FsfTGz7ziJVk3DH5zj56zivzn5PHLegCrTaAJMnL+Wy56dBoS/MT3+6WIuHNKDMwd04aIx\nU7n+pP0oKinnjEP2iQ7jvXTdVo79+6cMO6AjFxzegyl561m4poiJCwqS/nvHiszfYAZXHt2Lpz7L\nI7t9c5aur9rl9YDOrSmrqCR37RZaNw0x646T+fi7NVz+XDjDFv71FPr+aVzcY7RsEqrRD//6E/vy\nuxN2fw2+NgF/BHCHc+5k7/XNAM65e2O2+cDb5iszCwGrgSyXZOcKeJGGsW7LdlpkhpLO37urNhWX\nMn91UZ2GUHDOMW3pRg7LbrtL/cSnLd3A3q2b0rZFJi2bhKo0twE8eWF4AMBJC9dy71n9o8eOHNM5\nx9L1xXRv15x7xn7HCQd0jPZ4+X7TNpqEArRv2QTnHGNnr+akAzuREQywurCEJqEAbVtkUlJWwSMf\nL2JIr/bs27Elqwu30TQjSL/Ordn3lrH85th9ueHknesYUN8Bfw4w3Dl3hff6IuBw59zVMdvM8bbJ\n914v9rZJ2OlYAS8ie8LqwhI2bStl36yWCecl2JOueeVbjtuvIz8d0GWnfr4uAb9H+8Gb2ZXAlQDd\nuyfvqiciUh/23iv+HMwN5eGRibsy17fafJytBGKH4OvqrYu7jddEsxewvvqOnHNPO+cGOecGZWUl\nfypMRER2TW0CfhrQx8x6mlkmMBJ4p9o27wCXeMvnAJ8ka38XEZHdL2UTjXOu3MyuBj4AgsAzzrm5\nZnYnkOOcewcYA7xgZrnABsIfAiIi0oBq1QbvnBsLjK227vaY5RLgZ/VbNBER2RUNf0tZRER2CwW8\niEiaUsCLiKQpBbyISJpqsNEkzawAWJZyw/g6ALWfmqfx0/k0bjqfxu2Hdj49nHO1epCowQJ+V5hZ\nTm0f1fUDnU/jpvNp3HQ+iamJRkQkTSngRUTSlF8D/umGLkA90/k0bjqfxk3nk4Av2+BFRCQ1v9bg\nRUQkBd8FfKr5YRsjM+tmZhPNbJ6ZzTWza7z17czsIzNb5P3d1ltvZvaId46zzGzXJoDcDcwsaGbf\nmtl73uue3ny8ud78vJne+kY/X6+ZtTGz181svpl9Z2ZH+PzaXOv9ns0xs5fNrKmfro+ZPWNma72J\nhCLr6nw9zOwSb/tFZnZJvGPtCQnO52/e79ssM3vLzNrEvHezdz4LzOzkmPV1zz7nnG/+EB7NcjHQ\nC8gEZgL9GrpctSh3Z2Cgt9yK8By3/YAHgFHe+lHA/d7yCGAcYMAQYGpDn0Occ7oOeAl4z3v9X2Ck\nt/wk8Btv+SrgSW95JPBqQ5c9zrk8B1zhLWcCbfx6bYAuwBKgWcx1udRP1wc4BhgIzIlZV6frAbQD\n8ry/23rLbRvR+ZwEhLzl+2POp5+Xa02Anl7eBXc2+xr8F7KO/1BHAB/EvL4ZuLmhy7UT5/E/wpOY\nLwA6e+s6Awu85acIT2we2T66XWP4Q3jSl4+B44H3vP9c62J+YaPXifAw00d4yyFvO2voc4g5l728\nQLRq6/16bboAK7xgC3nX52S/XR8gu1og1ul6AOcBT8Wsr7JdQ59PtffOBF70lqtkWuT67Gz2+a2J\nJvLLG5HvrfMN7yvwAGAq0Mk5t8p7azXQyVtu7Of5EHAjUOm9bg9scs5Fpo6PLW/0XLz3C73tG4ue\nQAHwb6/JabSZtcCn18Y5txL4O7AcWEX433s6/r0+EXW9Ho36OlXzC8LfQqCez8dvAe9rZtYSeAP4\ng3Nuc+x7Lvyx3Oi7NJnZacBa59z0hi5LPQkR/vr8hHNuALCVcBNAlF+uDYDXNn0G4Q+ufYAWwPAG\nLVQ989P1SMXMbgXKgRd3x/79FvC1mR+2UTKzDMLh/qJz7k1v9Roz6+y93xlY661vzOd5JHC6mS0F\nXiHcTPMw0MbC8/FC1fLWar7eBpQP5DvnpnqvXycc+H68NgDDgCXOuQLnXBnwJuFr5tfrE1HX69HY\nrxNmdilwGnCB96EF9Xw+fgv42swP2+iYmRGe1vA759yDMW/FzmV7CeG2+cj6i70eAkOAwpivpw3K\nOXezc66rcy6b8L//J865C4CJhOfjhZrn0mjn63XOrQZWmNl+3qoTgHn48Np4lgNDzKy593sXOR9f\nXp8Ydb0eHwAnmVlb71vNSd66RsHMhhNu5jzdOVcc89Y7wEivd1NPoA/wNTubfQ19M2UnblaMINwL\nZTFwa0OXp5ZlPorwV8pZwAzvzwjCbZ0fA4uACUA7b3sDHvPOcTYwqKHPIcF5HcuOXjS9vF/EXOA1\noIm3vqn3Otd7v1dDlzvOeRwC5HjX523CvS58e22AvwDzgTnAC4R7ZPjm+gAvE75/UEb4G9blO3M9\nCLdt53p/Lmtk55NLuE09kgdPxmx/q3c+C4BTYtbXOfv0JKuISJryWxONiIjUkgJeRCRNKeBFRNKU\nAl5EJE0p4EVE0pQCXkQkTSngRUTSlAJeRCRN/T8AZbLAbRIGGAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f95c3234d68>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot(losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 97.81\n"
     ]
    }
   ],
   "source": [
    "evaluate_x = Variable(test_loader.dataset.test_data.type_as(torch.FloatTensor()))\n",
    "evaluate_y = Variable(test_loader.dataset.test_labels)\n",
    "if cuda:\n",
    "    evaluate_x, evaluate_y = evaluate_x.cuda(), evaluate_y.cuda()\n",
    "\n",
    "model.eval()\n",
    "output = model(evaluate_x)\n",
    "pred = output.data.max(1)[1]\n",
    "d = pred.eq(evaluate_y.data).cpu()\n",
    "accuracy = d.sum()/d.size()[0]\n",
    "\n",
    "print('Accuracy:', accuracy*100)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
